from diffgro.environments.metaworld.policies.policy import HeuristicPolicy
from diffgro.environments.metaworld.policies.skill import *


def inflate_policy(waypoints):
    """Return a heuristic policy that replays ``waypoints`` as a skill sequence.

    Every call defines a fresh ``KinematicV2Policy`` subclass of
    ``HeuristicPolicy`` and returns an instance of it, constructed with no
    arguments. Each time the policy's ``_initialize_skill`` hook runs, every
    waypoint is converted, in order, into a new skill object via ``get_skill``.

    Args:
        waypoints: Iterable of waypoint tuples accepted by ``get_skill``. It is
            captured by the closure and iterated again on every
            ``_initialize_skill`` call. A one-shot iterator (e.g. a generator)
            would therefore produce skills only the first time.
            NOTE(review): prefer passing a list or tuple.

    Returns:
        A ``KinematicV2Policy`` instance.
    """

    class KinematicV2Policy(HeuristicPolicy):
        def _initialize_skill(self, _):
            # The hook's argument is ignored; skills are rebuilt from scratch
            # on every call, so each call gets fresh skill objects.
            self.skills = list(map(get_skill, waypoints))

        @staticmethod
        def _parse_obs(obs):
            # Expose only the first three observation entries, labelled as the
            # hand position.
            hand_pos = obs[:3]
            return {"hand_pos": hand_pos}

    policy = KinematicV2Policy()
    return policy


def get_skill(waypoint, *, atol=0.005, p=8.0):
    """Build a skill object from a single waypoint tuple.

    Accepted waypoint shapes:

    * ``(skill_name, target_value)``
    * ``(skill_name, target_value, grab_effort)``
    * ``("GRIP",)``: GRIP does not need a target value.

    Args:
        waypoint: Sequence whose first item is one of ``"MOVE_X"``,
            ``"MOVE_Y"``, ``"MOVE_Z"``, ``"GRIP"``, ``"PUSH"`` or ``"PULL"``,
            optionally followed by a target value and a grab effort. The grab
            effort defaults to ``0`` when omitted.
        atol: Passed as ``atol`` to every non-GRIP skill. Presumably the
            tolerance for reaching the target; confirm in the ``skill`` module.
        p: Passed as ``p`` to every non-GRIP skill. Presumably a
            proportional gain; confirm in the ``skill`` module.

    Returns:
        A new skill instance. ``"GRIP"`` always yields
        ``GrabSkill(grab_effort=1)``; any target value or grab effort in a GRIP
        waypoint is ignored.

    Raises:
        ValueError: If the waypoint is empty, carries more than two values
            after the skill name, or is a non-GRIP waypoint without a target
            value.
        KeyError: If the skill name is not recognized.
    """
    skill_name, *values = waypoint

    # Validate arity up front so malformed waypoints fail with a clear message
    # instead of an opaque tuple-unpacking error.
    if len(values) > 2:
        raise ValueError(
            f"waypoint {waypoint!r} has {len(values)} values after the skill "
            "name; expected a target value and an optional grab effort"
        )

    skill_classes = {
        "MOVE_X": MoveXSkill,
        "MOVE_Y": MoveYSkill,
        "MOVE_Z": MoveZSkill,
        "GRIP": GrabSkill,
        "PUSH": PushSkill,
        "PULL": PullSkill,
    }
    try:
        skill_class = skill_classes[skill_name]
    except KeyError:
        raise KeyError(
            f"unknown skill {skill_name!r}; expected one of {sorted(skill_classes)}"
        ) from None

    # GRIP always uses grab_effort=1 and takes no target, so it is handled
    # before the target value is required.
    if skill_class is GrabSkill:
        return GrabSkill(grab_effort=1)

    if not values:
        raise ValueError(f"waypoint {waypoint!r} is missing a target value")

    target_value = values[0]
    grab_effort = values[1] if len(values) == 2 else 0
    return skill_class(target_value, atol=atol, p=p, grab_effort=grab_effort)
